Initial commit to pass scale as Tensor for multi_tensor_scale op #2594
vasunvidia wants to merge 4 commits into NVIDIA:main from
Conversation
Force-pushed from b2a5ae5 to 4081afc
Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
…sed but not actually enabled Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
…s is passed but not actually enabled" This reverts commit 74a9bcc. Signed-off-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Force-pushed from 4081afc to cfd4370
Greptile Overview
Greptile Summary
Added a tensor-based variant of multi_tensor_scale that takes the scale as a Tensor. Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
    participant Python as Python Code
    participant PyBind as PyTorch Binding
    participant Wrapper as C++ Wrapper
    participant CUDA as CUDA Kernel
    Python->>PyBind: multi_tensor_scale_tensor(chunk_size, noop_flag, tensor_lists, scale)
    PyBind->>Wrapper: multi_tensor_scale_tensor_cuda(...)
    Note over Wrapper: Convert PyTorch tensors<br/>to TE tensors
    Wrapper->>CUDA: nvte_multi_tensor_scale_tensor_cuda(scale as NVTETensor)
    Note over CUDA: Extract scale_tensor->data.dptr<br/>Cast to float*
    CUDA->>CUDA: multi_tensor_scale_tensor_cuda(scale as float*)
    Note over CUDA: Launch ScalePtrFunctor kernel<br/>with scale pointer
    CUDA->>CUDA: ScalePtrFunctor reads *scale_ptr<br/>Applies scaling to tensors
    CUDA-->>Python: Return scaled tensors
```
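To make the diagram concrete, the device-side step can be pictured as the sketch below. This is only illustrative: the kernel name, parameters, and flat indexing are assumptions, and the actual ScalePtrFunctor plugs into the multi-tensor-apply machinery rather than running as a standalone kernel.

```cuda
#include <cuda_runtime.h>

// Illustrative stand-in for ScalePtrFunctor: the scale arrives as a pointer
// into the single-element scale tensor's device buffer and is dereferenced
// on the GPU, rather than being passed by value from the host.
__global__ void scale_with_tensor_scale(const float *scale_ptr, const float *in,
                                        float *out, int n) {
  float scale = *scale_ptr;  // read the scalar from the scale tensor's memory
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * scale;
}
```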
```cpp
auto scale_cu = makeTransformerEngineTensor(scale);
auto [_, __, tensor_lists_ptr, num_lists, num_tensors] =
    makeTransformerEngineTensorList(tensor_lists);
std::cout << "multi_tensor_scale_cuda TENSOR\n";
```
Debug print statement left in production code:
```cpp
std::cout << "multi_tensor_scale_cuda TENSOR\n";
```
```cpp
Tensor *scale_tensor = convertNVTETensorCheck(scale);
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
    chunk_size, *convertNVTETensorCheck(noop_flag),
    convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list),
    reinterpret_cast<float *>(scale_tensor->data.dptr), stream);
```
Validate the scale tensor dtype and size before casting. Should check that:
- dtype is `kFloat32`
- `numel() == 1` (single-element tensor)
```cpp
Tensor *scale_tensor = convertNVTETensorCheck(scale);
NVTE_CHECK(scale_tensor->dtype() == DType::kFloat32,
           "scale tensor must be float32, got ", to_string(scale_tensor->dtype()));
NVTE_CHECK(scale_tensor->numel() == 1,
           "scale tensor must have exactly 1 element, got ", scale_tensor->numel());
multi_tensor_scale::multi_tensor_scale_tensor_cuda(
```
```cpp
// I'd like this kernel to propagate infs/nans.
// if(*noop_gmem == 1)
//   return;
float scale = *scale_ptr;
```
All threads in a block read the same device memory location without synchronization. This is functionally correct but can cause redundant memory traffic. Consider caching the value in shared memory or using `__ldg()` so the read goes through the read-only cache (see the sketch below).
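A minimal sketch of the `__ldg()` variant suggested above; the kernel and parameter names are illustrative, not the PR's actual code.

```cuda
#include <cuda_runtime.h>

// __ldg() routes the load through the read-only data cache, so the many
// threads that all read the same single-element scale tensor mostly hit
// the cache instead of repeatedly touching global memory.
__global__ void scale_with_ldg(const float *__restrict__ scale_ptr,
                               const float *in, float *out, int n) {
  float scale = __ldg(scale_ptr);
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * scale;
}
```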
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: